the decision blog

distinguishing birds from airplanes with a pytorch convnet

Written on

This post, in which we'll create a convolutional net to discern birds from airplanes, is inspired by a couple of chapters from the book Deep Learning with PyTorch, by Stevens, Antiga, and Viehmann. Here are some of the images that we'll be dealing with:

birds and planes

First let's load up the data and normalize it.

from torchvision import datasets, transforms
from torch import optim
import torch
import sys
cifar = datasets.CIFAR10('.',train=True,download=True, \
                         transform=transforms.ToTensor())
cifar_val = datasets.CIFAR10('.',train=False,download=True, \
                             transform=transforms.ToTensor())

# normalize the data
imgs     = torch.stack([img_t for img_t, _ in cifar],dim=3)
imgs_val = torch.stack([img_t for img_t, _ in cifar_val],dim=3)
mean     = imgs.view(3,-1).mean(dim=1)
mean_val = imgs_val.view(3,-1).mean(dim=1)
std      = imgs.view(3,-1).std(dim=1)
std_val  = imgs_val.view(3,-1).std(dim=1)
cifar    = datasets.CIFAR10('.',train=True,download=False, \
                            transform=transforms.Compose([ \
                            transforms.ToTensor(), \
                            transforms.Normalize(mean,std)]))
cifar_val = datasets.CIFAR10('.',train=False,download=False, \
                             transform=transforms.Compose([ \
                             transforms.ToTensor(), \
                             transforms.Normalize(mean_val, \
                             std_val)]))
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:11<00:00, 14695382.58it/s]


Extracting ./cifar-10-python.tar.gz to .
Files already downloaded and verified

This dataset contains images of lots of other things besides birds and planes, so we'll need to filter out the images we're interested in.

# filter the images for birds and airplanes
label_map = {0:0,2:1}
class_names = ['airplane','bird']
cifar = [(img,label_map[label]) \
        for img,label in cifar  \
        if label in [0,2]]
cifar_val = [(img,label_map[label]) \
        for img,label in cifar_val \
        if label in [0,2]]

Next we'll define our neural net architecture.

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # convert 3 RGB channels into 16 channels
        # we're starting with 32x32 pixels
        self.conv1 = nn.Conv2d(3,16,kernel_size=3,padding=1)
        self.act1  = nn.Tanh()
        # pool and reduce images to 16x16 pixels
        self.pool1 = nn.MaxPool2d(2)
        # convert 16 channels into 8 channels
        self.conv2 = nn.Conv2d(16,8,kernel_size=3,padding=1)
        self.act2  = nn.Tanh()
        # pool and reduce the images to 8x8 pixels
        # we now have 8 channels of 8x8 images
        self.pool2 = nn.MaxPool2d(2)
        # fully connected layers follow
        self.fc1   = nn.Linear(8*8*8,32)
        self.act3  = nn.Tanh()
        self.fc2   = nn.Linear(32,2)
        self.act4  = nn.LogSoftmax(dim=1)

    def forward(self,x):
        out = self.pool1(self.act1(self.conv1(x)))
        out = self.pool2(self.act2(self.conv2(out)))
        out = out.view(-1,8*8*8)
        out = self.act3(self.fc1(out))
        out = self.act4(self.fc2(out))
        return out

It's time to train the model; we'll do it in randomized batches.

num_out = 2
num_input  = 512
num_hidden = 512
train_loader = torch.utils.data.DataLoader(cifar,batch_size=64, \
                                           shuffle=True)
model = Net()
loss_func = nn.NLLLoss()
learning_rate = 0.01
optimizer  = optim.SGD(model.parameters(),learning_rate)
err_tol    = 0.1
max_epoch  = 200
epoch      = 0
err = 1.e5
while ((epoch<max_epoch)&(err>err_tol)):
    err = 0.0
    for imgs, labels in train_loader:
        batch_size = imgs.shape[0]
        outputs = model(imgs)
        loss = loss_func(outputs,labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        err += loss.item()

    err /= len(train_loader)
    epoch += 1
    if (epoch%10==0):
        print("epoch: %d, loss: %f" % (epoch, err))
epoch: 10, loss: 0.348339
epoch: 20, loss: 0.301792
epoch: 30, loss: 0.271439
epoch: 40, loss: 0.247062
epoch: 50, loss: 0.227016
epoch: 60, loss: 0.210078
epoch: 70, loss: 0.192694
epoch: 80, loss: 0.175032
epoch: 90, loss: 0.160963
epoch: 100, loss: 0.146820
epoch: 110, loss: 0.134589
epoch: 120, loss: 0.125042
epoch: 130, loss: 0.113120
epoch: 140, loss: 0.102200

As usual, we'd like to know how well our model performs out of sample.

# check out-of-sample accuracy
num_correct = 0
for img, label in cifar_val:
    label_true = torch.tensor(label)
    out = model(img)
    label_pred = torch.exp(out).argmax()

    if (label_true==label_pred):
        num_correct += 1

percent = (num_correct/len(cifar_val))*100.0
print("percent correct OOS: ",percent)
percent correct OOS:  89.5

Given the low image quality and the size of the neural net, 90% doesn't seem too bad!

Discuss on Twitter